#!/usr/bin/env python
# -*- coding:utf-8 -*-
from __future__ import nested_scopes, generators, division, absolute_import, with_statement, print_function, \
    unicode_literals

import os
import argparse
import numpy as np
import yaml
import json
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

from const import DEFAULT_CONFIG, MODEL_PATH, STATIC_PATH, DEFAULT_DB
from util import load_model
from model.LSTM import load_lstm_model
from train.func import load_data, word2vec_train, data_set

__all__ = ['train']


def train(conf=None, update=True, last_id=None):
    """Train (or incrementally update) the LSTM model and persist it.

    :param conf: optional dict of overrides merged on top of DEFAULT_CONFIG;
                 DEFAULT_CONFIG itself is never mutated.
    :param update: when True and a previously saved model exists on disk,
                   continue training it; otherwise build a fresh LSTM model.
    :param last_id: resume marker forwarded to load_data so only newer rows
                    are loaded (presumably a db row id — confirm in load_data).
    :return: (score, next_id) — evaluation result on the test split and the
             id to resume from on the next run.
    """
    if conf:
        # Merge overrides into a *copy*: the original updated the shared
        # DEFAULT_CONFIG module constant in place, leaking overrides from
        # one call into every later call.
        _c = dict(DEFAULT_CONFIG)
        _c.update(conf)
        conf = _c
    else:
        conf = DEFAULT_CONFIG

    np.random.seed(1337)  # fixed seed for reproducible runs
    config = conf.get('fit')

    texts, labels, next_id = load_data(last_id, db=conf.get('db'))

    index_dict, word_vectors, combined = word2vec_train(texts, conf=conf.get('word2vec'), update=update,
                                                        **config.get('save'))
    nb_words, embedding_weights, x_train, y_train, x_test, y_test = data_set(index_dict, word_vectors, combined, labels)
    # Model architecture (yaml) and weights (h5) live side by side under MODEL_PATH.
    m_path = os.path.join(MODEL_PATH, '%s.yaml' % (config.get('save').get('model'),))
    w_path = os.path.join(MODEL_PATH, '%s.h5' % (config.get('save').get('model'),))
    if update and os.path.exists(m_path) and os.path.exists(w_path):
        # Resume from the previously saved model.
        model = load_model(m_path, w_path)
    else:
        model = load_lstm_model(nb_words, embedding_weights)
    # NOTE: Keras ignores validation_split when validation_data is supplied;
    # both are passed here to keep the original call unchanged.
    model.fit(x_train, y_train, epochs=config.get('epoch'), batch_size=config.get('batch_size'),
              verbose=config.get('verbose'), shuffle=config.get('shuffle'),
              validation_split=config.get('validation_split'), validation_data=(x_test, y_test),
              callbacks=[EarlyStopping(patience=2), ReduceLROnPlateau()])
    score = model.evaluate(x_test, y_test, batch_size=config.get("batch_size"))
    model_y = model.to_yaml()
    if config.get('save') and config.get('save').get('save'):
        with open(m_path, 'w') as _fp:
            _fp.write(model_y)
        model.save_weights(w_path)
    return score, next_id


if __name__ == '__main__':
    # Parse command-line arguments.
    ap = argparse.ArgumentParser(description='a crawler model')
    # -c takes a path value. The original used action='store_const' with
    # const=None, which made the flag value-less and always None,
    # contradicting its own help text.
    ap.add_argument('-c', '--config',
                    default=None,
                    help='set config path(default static/conf.yaml)')
    # store_true: passing -n sets renew=True so update=not renew is False
    # and a brand-new model is built. The original's store_const/const=False
    # inverted this and always kept update=True.
    ap.add_argument('-n', '--renew',
                    action='store_true',
                    help='build new model(not update).')
    # Parse the real sys.argv — the original parsed "".split() (an empty
    # list), so all CLI options were silently ignored.
    args = ap.parse_args()
    if args.config is None:
        cp = os.path.join(STATIC_PATH, 'conf.yaml')
    else:
        cp = os.path.abspath(args.config)

    # Parse the configuration file into a copy so the shared DEFAULT_CONFIG
    # module constant is not mutated.
    c = dict(DEFAULT_CONFIG)
    if os.path.exists(cp):
        with open(cp, 'r') as fp:
            c.update(yaml.safe_load(fp))

    # Resume-state db (holds last processed id), also copied before update.
    ip = os.path.join(STATIC_PATH, 'paladin.db')
    i = dict(DEFAULT_DB)
    if os.path.exists(ip):
        with open(ip, 'r') as fp:
            i.update(json.load(fp))

    s, nid = train(conf=c, update=not args.renew, last_id=i.get('last_id'))
    i['last_id'] = nid
    # json.dump writes into the file; the original called json.dumps and
    # discarded the result, truncating the state file to empty bytes.
    with open(ip, 'w') as fp:
        json.dump(i, fp)
    print('Score: %s' % (s,))
